# -*- coding:utf-8 -*- 
# Scanner module: scan binaries to extract information useful for exploits

import subprocess
import mmap
import re
from ropgenerator.Gadget import Gadget
import ropgenerator.Architecture as Arch

# Test elftools 
from elftools.elf.elffile import ELFFile
from elftools.elf.relocation import RelocationSection
from elftools.elf.sections import SymbolTableSection, NullSection

# Path of the binary handed to initScanner (None until it is called)
binary_name = None
# ELFFile for that binary, or None when the current architecture is not ELF
binary_ELF = None

def initScanner(filename):
    """
    Record 'filename' as the binary to scan and, when the current
    architecture is an ELF one, parse it into the module-level binary_ELF.

    For ELF binaries the file object is intentionally left open: the
    ELFFile keeps using the stream it was built from to read section
    contents later on. In every other case (non-ELF architecture, or the
    parser raising) the file is closed so the handle is not leaked.

    Raises whatever open() or ELFFile() raise; binary_ELF is only
    replaced once parsing succeeded.
    """
    global binary_name
    global binary_ELF

    binary_name = filename
    f = open(binary_name, 'rb')
    try:
        parsed = ELFFile(f) if Arch.currentIsELF() else None
    except Exception:
        # Parsing failed: nobody will ever use this handle
        f.close()
        raise
    if parsed is None:
        # Not an ELF: the file content is not needed by this module
        f.close()
    binary_ELF = parsed

##################
# Manage offsets #
##################
# Value added to every address computed by this module
_offset = 0


def set_offset(offset):
    """
    Store 'offset' as the value added to the addresses computed by this
    module. Always returns True.
    """
    global _offset
    _offset = offset
    return True


def reset_offset():
    """
    Put the module offset back to 0. Returns nothing.
    """
    set_offset(0)
    

def getAllFunctions():
    """
    Return the functions listed in the '.rela.plt' relocation section.

    Returns a list of (name, address) pairs, one per RELA relocation
    entry, where the name comes from the symbol table linked to the
    relocation section and the address is
    r_offset + sh_addr of '.rela.plt' + the module offset.

    An empty list is returned when no ELF is loaded, when '.rela.plt' is
    not a relocation section, or when its linked symbol table is a null
    section.
    """
    # No ELF parsed (non-ELF binary or scanner not initialised)
    if binary_ELF is None:
        return []

    # Get functions from relocations
    relasec = binary_ELF.get_section_by_name('.rela.plt')
    if not isinstance(relasec, RelocationSection):
        return []
    symbols = binary_ELF.get_section(relasec.header['sh_link'])
    if isinstance(symbols, NullSection):
        return []

    # NOTE(review): for linked binaries r_offset is usually already a
    # virtual address; confirm that adding the section's sh_addr is intended
    relasec_addr = relasec.header['sh_addr']
    return [
        (symbols.get_symbol(reloc['r_info_sym']).name,
         reloc['r_offset'] + relasec_addr + _offset)
        for reloc in relasec.iter_relocations()
        if reloc.is_RELA()
    ]

def getSectionAddress(name):
    """
    Returns the address of the section 'name' in the loaded binary, i.e.
    the 'sh_addr' field of its header plus the module offset.

    Returns None when no ELF is loaded or when the section lookup gives
    nothing.
    """
    # No ELF parsed (non-ELF binary or scanner not initialised)
    if binary_ELF is None:
        return None
    section = binary_ELF.get_section_by_name(name)
    if not section:
        return None
    return section.header["sh_addr"] + _offset
     
def getSymbolSections():
    """
    Return the list of symbol table sections of the loaded ELF.

    Returns an empty list when no ELF is loaded or when the current
    architecture is not an ELF one.
    """
    # currentIsELF must be *called*: the bare function object is always
    # truthy, which would make this guard dead code
    if binary_ELF is None or not Arch.currentIsELF():
        return []
    return [section for section in binary_ELF.iter_sections()
            if isinstance(section, SymbolTableSection)]
    
def getFunctionAddress(name):
    """
    Looks for the function 'name' in the loaded ELF binary.

    The '.rela.plt' relocation entries are searched first. When the name
    is not found there, or the binary has no usable '.rela.plt', the
    symbol table sections are searched instead.

    Returns a pair (name, address) as (str, int), the address including
    the module offset, or (None, None) when no ELF is loaded or the
    function cannot be found.
    """
    if binary_ELF is None or not Arch.currentIsELF():
        return (None, None)

    # Get function in relocations. A missing '.rela.plt' must not abort
    # the lookup: the symbol tables below may still know the function.
    relasec = binary_ELF.get_section_by_name('.rela.plt')
    if isinstance(relasec, RelocationSection):
        symbols = binary_ELF.get_section(relasec.header['sh_link'])
        if not isinstance(symbols, NullSection):
            # NOTE(review): for linked binaries r_offset is usually already
            # a virtual address; confirm that adding sh_addr is intended
            relasec_addr = relasec.header['sh_addr']
            for reloc in relasec.iter_relocations():
                if symbols.get_symbol(reloc['r_info_sym']).name == name:
                    return (name, reloc['r_offset'] + relasec_addr + _offset)

    # Get function from symbol table sections
    for symsec in getSymbolSections():
        for symbol in symsec.get_symbol_by_name(name) or []:
            # An undefined symbol with a zero value carries no address;
            # returning it would yield the bare offset as a bogus address
            if symbol['st_shndx'] == 'SHN_UNDEF' and symbol['st_value'] == 0:
                continue
            return (name, symbol['st_value'] + _offset)
    return (None, None)

def findBytes(byte_string, badBytes=None, add_null=False):
    """
    Locate 'byte_string' in the '.text' and '.data' sections of the loaded
    ELF, cutting it into several chunks when it cannot be found in one
    piece.

    Parameters
    ----------
    byte_string : the bytes to look for
    badBytes : bytes that must not appear in the returned addresses. Each
               one is compared with the two-hex-digit bytes of the address
               (presumably lowercase strings such as '0a' - TODO confirm
               against callers). None or empty means no restriction.
    add_null : if True then every chunk must be followed by a terminating
               null byte in the binary, and that null byte is part of the
               returned chunk

    Returns
    -------
    A list of [address, chunk] pairs. Addresses include the section
    address and the module offset. Put end to end, the chunks (minus the
    nulls added by add_null) spell 'byte_string'. The search is greedy:
    at each step the longest prefix of what remains that can be found at
    an acceptable address is taken, sections being tried in order.

    An empty list is returned when no ELF is loaded, when neither section
    is available, when 'byte_string' is empty, or when some part of it
    cannot be found at an address free of bad bytes.

    Example:
        byte_string = 'abc' then result is the address of a string 'abc\\x00'
        or a list of addresses s.t elements form 'abc' like 'ab\\x00' 'c\\x00'

    NOTE(review): the str() conversions assume Python 2 byte strings.
    """
    null = '\x00'
    bad_bytes = badBytes if badBytes is not None else []

    def _needle(prefix):
        # What must actually be present in the binary for this prefix
        if add_null and not prefix.endswith(null):
            return prefix + null
        return prefix

    def _longest_match(data, string, start, max_len):
        # Longest prefix of 'string' (at most max_len long) found in
        # 'data' at or after 'start', as (position, length, needle);
        # (-1, 0, None) when not even one byte matches
        for length in range(max_len, 0, -1):
            needle = _needle(string[:length])
            position = data.find(needle, start)
            if position != -1:
                return (position, length, needle)
        return (-1, 0, None)

    def _address_is_clean(addr):
        # True when none of the bad bytes appears in the address
        if not bad_bytes:
            return True
        digits = '{:x}'.format(addr)
        # Two hex digits per octet; always an even number of digits so the
        # pairs below line up with the real bytes of the address
        width = max(2 * Arch.currentArch.octets,
                    len(digits) + len(digits) % 2)
        addr_bytes = re.findall('..', digits.zfill(width))
        return not any(byte in addr_bytes for byte in bad_bytes)

    def _locate(string, sections):
        # First acceptable (address, length, needle) for a prefix of
        # 'string', or None when no section provides one
        for (data, section_addr) in sections:
            start = 0
            max_len = len(string)
            while True:
                (position, length, needle) = _longest_match(
                    data, string, start, max_len)
                if length == 0:
                    # Nothing (left) in this section, try the next one
                    break
                address = position + section_addr + _offset
                if _address_is_clean(address):
                    return (address, length, needle)
                # Bad address: look for a later occurrence. Prefixes
                # longer than the one just found do not occur from 'start'
                # on, so they cannot occur further either.
                start = position + 1
                max_len = length
        return None

    # Function body
    if binary_ELF is None:
        return []

    # Getting data from all sections
    sections = []
    for section_name in (".text", '.data'):
        section = binary_ELF.get_section_by_name(section_name)
        if section is None or section.is_null():
            continue
        sections.append((str(section.data()), section.header['sh_addr']))
    if not sections:
        return []

    # Getting bytes as substrings
    res = []
    remaining = str(byte_string)
    while remaining:
        match = _locate(remaining, sections)
        if match is None:
            # Couldn't find the next bytes in any section
            return []
        (address, length, needle) = match
        # The needle is exactly what sits at 'address' in the binary
        res.append([address, needle])
        remaining = remaining[length:]

    return res
